//给定两个大小为 m 和 n 的正序（从小到大）数组 nums1 和 nums2。 
//
// 请你找出这两个正序数组的中位数，并且要求算法的时间复杂度为 O(log(m + n))。 
//
// 你可以假设 nums1 和 nums2 不会同时为空。 
//
// 
//
// 示例 1: 
//
// nums1 = [1, 3]
//nums2 = [2]
//
//则中位数是 2.0
// 
//
// 示例 2: 
//
// nums1 = [1, 2]
//nums2 = [3, 4]
//
//则中位数是 (2 + 3)/2 = 2.5
// 
// Related Topics 数组 二分查找 分治算法

package leetcode.editor.cn;

import org.junit.Assert;
import org.junit.Test;

public class MedianOfTwoSortedArrays {
    /**
     * Ad-hoc driver: prints the median of a {@code null} (treated as empty) first array and {@code [2]}.
     */
    public static void main(String[] args) {
        Solution solution = new MedianOfTwoSortedArrays().new Solution();
        double v = solution.findMedianSortedArrays(null, new int[]{2});
        System.out.println(v);
    }

    @Test
    public void test1() {
        Solution solution = new MedianOfTwoSortedArrays().new Solution();
        double v = solution.findMedianSortedArrays(null, new int[]{2});
        Assert.assertEquals(2.0, v, 0.1);

        v = solution.findMedianSortedArrays(new int[]{1}, null);
        Assert.assertEquals(1.0, v, 0.1);

        v = solution.findMedianSortedArrays(new int[]{1, 2}, new int[]{3, 4});
        Assert.assertEquals(2.5, v, 0.1);

        v = solution.findMedianSortedArrays(new int[]{1, 2}, new int[]{3});
        Assert.assertEquals(2.0, v, 0.1);

        v = solution.findMedianSortedArrays(new int[]{1, 2}, new int[]{1, 3});
        Assert.assertEquals(1.5, v, 0.1);

        v = solution.findMedianSortedArrays(new int[]{-2, -1}, new int[]{-4, -3});
        Assert.assertEquals(-2.5, v, 0.1);

        // [3] and [-2, -1] merge to [-2, -1, 3], whose median is -1. The previous expectation
        // of -2.5 was copied from the case above and made this test fail.
        v = solution.findMedianSortedArrays(new int[]{3}, new int[]{-2, -1});
        Assert.assertEquals(-1.0, v, 0.1);

        // Interleaved inputs: merged = [1, 3, 7, 8, 9, 11, 15, 18, 19, 21, 25].
        v = solution.findMedianSortedArrays(new int[]{1, 3, 8, 9, 15}, new int[]{7, 11, 18, 19, 21, 25});
        Assert.assertEquals(11.0, v, 0.1);

        // Averaging the two middle values must not overflow int arithmetic.
        v = solution.findMedianSortedArrays(new int[]{Integer.MAX_VALUE}, new int[]{Integer.MAX_VALUE});
        Assert.assertEquals((double) Integer.MAX_VALUE, v, 0.1);
    }

    @Test
    public void mergeImplementationAgreesWithSolution() {
        Solution solution = new MedianOfTwoSortedArrays().new Solution();
        S1 s1 = new S1();
        int[][][] cases = {
                {{1, 3}, {2}},
                {{1, 2}, {3, 4}},
                {{1, 2}, {1, 3}},
                {{-2, -1}, {-4, -3}},
                {{3}, {-2, -1}},
                {{1, 3, 8, 9, 15}, {7, 11, 18, 19, 21, 25}},
                {{Integer.MAX_VALUE}, {Integer.MAX_VALUE}},
        };
        for (int[][] c : cases) {
            Assert.assertEquals(s1.find(c[0], c[1]), solution.findMedianSortedArrays(c[0], c[1]), 1e-9);
        }
        Assert.assertEquals(2.0, s1.find(null, new int[]{2}), 1e-9);
    }

    @Test(expected = IllegalArgumentException.class)
    public void bothEmptyIsRejected() {
        new MedianOfTwoSortedArrays().new Solution().findMedianSortedArrays(new int[0], null);
    }

    //leetcode submit region begin(Prohibit modification and deletion)
class Solution {
    /**
     * Returns the median of all elements of two ascending arrays in O(log(min(m, n))) time
     * and O(1) extra space.
     *
     * <p>The shorter array is binary-searched for a split point {@code i}. The longer array is
     * then split at {@code j = half - i}, so the two left parts together hold exactly
     * {@code half = ceil((m + n) / 2)} elements. The split is correct once every left element
     * is {@code <=} every right element.
     *
     * @param nums1 first ascending array; {@code null} is treated as empty
     * @param nums2 second ascending array; {@code null} is treated as empty
     * @return the middle element, or the mean of the two middle elements when the total
     *         length is even
     * @throws IllegalArgumentException if both arrays are null/empty, or if no valid split
     *         exists (only possible when an input is not sorted ascending)
     */
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        int[] shorter = nums1 == null ? new int[0] : nums1;
        int[] longer = nums2 == null ? new int[0] : nums2;
        // Searching the shorter array keeps the range small and guarantees 0 <= j <= n below.
        if (shorter.length > longer.length) {
            int[] swap = shorter;
            shorter = longer;
            longer = swap;
        }
        int m = shorter.length;
        int n = longer.length;
        if (n == 0) {
            // m <= n, so both arrays are empty.
            throw new IllegalArgumentException("nums1 and nums2 must not both be empty");
        }

        // Size of the left half; it takes the extra element when the total is odd.
        int half = (m + n + 1) / 2;
        int lo = 0;
        int hi = m;
        while (lo <= hi) {
            int i = (lo + hi) >>> 1;   // elements taken from `shorter` into the left half
            int j = half - i;          // elements taken from `longer` into the left half

            // Sentinels stand in for neighbours that lie outside an array.
            int shortLeft = i == 0 ? Integer.MIN_VALUE : shorter[i - 1];
            int shortRight = i == m ? Integer.MAX_VALUE : shorter[i];
            int longLeft = j == 0 ? Integer.MIN_VALUE : longer[j - 1];
            int longRight = j == n ? Integer.MAX_VALUE : longer[j];

            if (shortLeft > longRight) {
                hi = i - 1;            // took too many from `shorter`
            } else if (longLeft > shortRight) {
                lo = i + 1;            // took too few from `shorter`
            } else {
                int leftMax = Math.max(shortLeft, longLeft);
                if ((m + n) % 2 == 1) {
                    return leftMax;
                }
                int rightMin = Math.min(shortRight, longRight);
                // Widen before adding so two large ints cannot overflow.
                return ((long) leftMax + rightMin) / 2.0;
            }
        }
        throw new IllegalArgumentException("nums1 and nums2 must be sorted in ascending order");
    }
}
//leetcode submit region end(Prohibit modification and deletion)

    /**
     * Alternative O(m + n) implementation: merges both arrays and reads the median from the
     * merged result.
     */
    static class S1 {
        /**
         * Returns the median of all elements of two ascending arrays by full merge.
         *
         * @param nums1 first ascending array; {@code null} is treated as empty
         * @param nums2 second ascending array; {@code null} is treated as empty
         * @return the middle element, or the mean of the two middle elements when the total
         *         length is even
         * @throws IllegalArgumentException if both arrays are null/empty
         */
        double find(int[] nums1, int[] nums2) {
            int[] a = nums1 == null ? new int[0] : nums1;
            int[] b = nums2 == null ? new int[0] : nums2;
            int m = a.length;
            int n = b.length;
            if (m + n == 0) {
                throw new IllegalArgumentException("nums1 and nums2 must not both be empty");
            }

            int[] merged = new int[m + n];
            int i = 0;
            int j = 0;
            int k = 0;
            while (i < m && j < n) {
                merged[k++] = a[i] < b[j] ? a[i++] : b[j++];
            }
            // At most one of these copies is non-empty.
            System.arraycopy(a, i, merged, k, m - i);
            k += m - i;
            System.arraycopy(b, j, merged, k, n - j);

            int mid = merged.length / 2;
            if (merged.length % 2 == 1) {
                return merged[mid];
            }
            // Widen before adding so two large ints cannot overflow.
            return ((long) merged[mid - 1] + merged[mid]) / 2.0;
        }
    }
}